"""
直接用 Conv1d + AdaptiveAvgPool1d 暴力解
准确率只有 30% 多
"""

# %%
import copy
import itertools
import logging
import os
import random
import re
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple, Type,
    TypeVar, Union,
)

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, optim
from torch.autograd import Variable

from utils import Args, D, timeit

# Configure the root logger once (DEBUG level, with file/line in each record)
# and obtain this module's logger.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
)
logger = logging.getLogger(__name__)


def get_args() -> Args:
    """
    Declare this script's configurable parameters and return them as ``Args``.

    Parameters: batch size, learning rate, epoch count, dataset directory
    (defaults to ``dataset/traffic`` next to this file), save directory and
    number of classes. How the values are resolved (e.g. parsed from the
    command line) is up to ``utils.Args`` — presumably CLI parsing, TODO
    confirm.
    """
    script_dir = os.path.dirname(__file__)
    specs = [
        D("batchSize", int, 32),
        D("learningRate", float, 1e-3),
        D("numEpochs", int, 5000),
        D("dataDir", str, os.path.join(script_dir, "dataset/traffic")),
        D("saveDir", str, None),
        D("nClass", int, 50),
    ]
    return Args(specs)


def inf_loop() -> Iterator[None]:
    """
    Yield ``None`` forever.

    Serves as an unbounded epoch counter when no epoch limit is given; the
    consumer is responsible for stopping the iteration.
    """
    while True:
        yield None


class Sample(NamedTuple):
    """
    One dataset sample: a record sequence plus its class label.

    ``tag`` records where the sample came from; for example the file
    ``49-15`` of the training set is tagged ``train-49-15``.
    """
    # Sequence of (time, stream_type) records.
    data: List[Tuple[float, int]]
    # Integer class label.
    label: int
    # Origin marker such as "train-49-15".
    tag: str

    def to_tensor_sample(self) -> Tuple[Tensor, int]:
        """
        Convert to a ``(Tensor, int)`` pair.

        The tensor is laid out as (channel, length): the timestamps and the
        up/down stream markers each occupy one channel (row).
        """
        records = Tensor(self.data)  # (length, 2) for non-empty data
        return records.t(), self.label


class RawDataSet(NamedTuple):
    """
    Raw dataset: the parsed training and test samples.
    """
    # Samples used for training.
    train: List[Sample]
    # Samples used for evaluation.
    test: List[Sample]

    def train_data_loader(self, batch_size: int = 1,
                          epochs: Optional[int] = 1) -> Iterator[List[Tuple[Tensor, Tensor]]]:
        """Mini-batch generator over ``self.train`` (see ``data_loader``)."""
        return self.data_loader(self.train, batch_size, epochs)

    def test_data_loader(self, batch_size: int = 1,
                         epochs: Optional[int] = 1) -> Iterator[List[Tuple[Tensor, Tensor]]]:
        """Mini-batch generator over ``self.test`` (see ``data_loader``)."""
        return self.data_loader(self.test, batch_size, epochs)

    @classmethod
    def data_loader(cls, _dataset: List[Sample],
                    batch_size: int = 1,
                    epochs: Optional[int] = 1) -> Iterator[List[Tuple[Tensor, Tensor]]]:
        """
        Generate mini-batches over ``_dataset``.

        Samples have different lengths, so they cannot be combined with
        ``torch.stack``. Each yielded batch is instead a list of
        ``(x, label)`` pairs, where ``x`` is the sample's tensor with a
        leading batch dimension of size 1 and ``label`` is a 1-element
        ``LongTensor``.

        :param _dataset: samples to iterate, in list order (no shuffling).
        :param batch_size: maximum number of samples per batch; must be >= 1.
        :param epochs: number of passes over ``_dataset``; ``None`` means
            iterate forever.
        :return: a generator of batches. The last batch of an epoch may be
            smaller than ``batch_size``; empty batches are never yielded.
        :raises ValueError: (on first iteration) if ``batch_size < 1``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not _dataset:
            # Nothing to yield; also avoids spinning forever when epochs is None.
            return
        _iter = inf_loop() if epochs is None else range(epochs)

        def to_res_batch(batch: List[Tuple[Tensor, int]]) -> List[Tuple[Tensor, Tensor]]:
            # Give each sample its own batch dimension of size 1 so the model
            # can consume variable-length samples one at a time.
            return [
                (torch.unsqueeze(_data, 0), torch.LongTensor((_label,)))
                for _data, _label in batch
            ]

        for _ in _iter:
            batch = []
            for sample in _dataset:
                batch.append(sample.to_tensor_sample())
                if len(batch) >= batch_size:
                    yield to_res_batch(batch)
                    batch = []
            # Flush only a non-empty remainder: an empty batch would make a
            # consumer run an optimizer step with no new gradients.
            if batch:
                yield to_res_batch(batch)


def _read_trace_file(file_path: str) -> List[Tuple[float, int]]:
    """
    Parse one trace file into a list of ``(time, stream_type)`` records.

    Each non-blank line must hold exactly two whitespace-separated fields:
    a float timestamp and an integer stream type. A run of consecutive lines
    with the same timestamp and the same stream type is merged into a single
    record whose second field is the sum of their stream-type values.

    :raises ValueError: if a line does not have exactly two fields, a field
        cannot be converted, or the file contains no records.
    """
    records: List[Tuple[float, int]] = []
    # Raw stream type of the packets folded into records[-1]. Compare against
    # this rather than records[-1][1], which grows as packets are merged
    # (otherwise only pairs would merge, not longer runs).
    last_type: Optional[int] = None
    with open(file_path, "r") as fr:
        for line in fr:
            fields = line.split()  # any run of whitespace separates fields
            if not fields:
                continue  # tolerate blank lines
            _time_str, _stream_type_str = fields
            _time, _stream_type = float(_time_str), int(_stream_type_str)
            if records and records[-1][0] == _time and last_type == _stream_type:
                _t, _v = records[-1]  # tuples are immutable; rebuild the record
                records[-1] = (_t, _v + _stream_type)
            else:
                records.append((_time, _stream_type))
                last_type = _stream_type
    if not records:
        raise ValueError(f"no records in {file_path}")
    return records


def read_data(data_dir: str, max_workers: int = 12,
              num_train: Optional[int] = None, num_test: Optional[int] = None) -> RawDataSet:
    """
    Read the raw traffic dataset from disk.

    Training samples come from ``<data_dir>/defence`` and test samples from
    ``<data_dir>/undefence``. A file name must contain ``<label>-<index>``
    (e.g. ``49-15``); the first such match supplies the integer label, and
    files without one are skipped. Files are parsed concurrently in a thread
    pool. Files that fail to parse are dropped and reported as warnings.

    :param data_dir: dataset root directory.
    :param max_workers: thread-pool size used for reading files.
    :param num_train: if given, only the first ``num_train`` entries of the
        training directory listing are considered.
    :param num_test: same as ``num_train``, for the test directory.
    :return: ``RawDataSet(train, test)``; within each split, samples keep the
        order returned by ``os.listdir``.
    """
    train_dir = os.path.join(data_dir, "defence")
    test_dir = os.path.join(data_dir, "undefence")
    train_files = os.listdir(train_dir)
    test_files = os.listdir(test_dir)
    # Greedy groups so both numbers are captured in full ("49-15" -> 49, 15).
    file_name_pattern = re.compile(r"(\d+)-(\d+)")

    def fn_train_tag(s):
        return f"train-{s}"

    def fn_test_tag(s):
        return f"test-{s}"

    # Observed on the reference dataset: 4500 files in each directory.

    def build_raw_dataset(_dir: str, files: List[str], fn_tag: Callable[[str], str],
                          num_samples: Optional[int] = None) -> List[Sample]:
        # Results are stored by index so the output keeps listing order even
        # though the worker threads finish in arbitrary order.
        samples: List[Optional[Sample]] = [None] * len(files)

        def f(file_name: str, idx: int) -> None:
            match = file_name_pattern.search(file_name)
            if match is None:
                # File name does not follow the naming convention; skip it.
                return
            _label = int(match.group(1))
            _data = _read_trace_file(os.path.join(_dir, file_name))
            samples[idx] = Sample(_data, _label, fn_tag(file_name))

        pending = []
        # Leaving the `with` block waits for every submitted task.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for i, file_name in enumerate(files):
                if num_samples is not None and i >= num_samples:
                    break
                pending.append((file_name, executor.submit(f, file_name, i)))
        # Failed files are dropped (best effort), but report why instead of
        # silently discarding the exception stored in the future.
        for file_name, fut in pending:
            exc = fut.exception()
            if exc is not None:
                logger.warning("failed to read %s: %r",
                               os.path.join(_dir, file_name), exc)
        return [ele for ele in samples if ele is not None]

    train_dataset = build_raw_dataset(
        train_dir, train_files, fn_train_tag, num_train)
    test_dataset = build_raw_dataset(
        test_dir, test_files, fn_test_tag, num_test)
    return RawDataSet(train_dataset, test_dataset)


def build_test_func(_validation_data_loader):
    """
    Build an evaluation function over a validation data loader.

    :param _validation_data_loader: iterable of batches, each a list of
        ``(x, label)`` tensor pairs. It is iterated once per call of the
        returned function, so a one-shot generator supports a single call.
    :return: ``_calc_top1_top5(model) -> (top1_accuracy, top5_accuracy)``.
        The accuracies are fractions in [0, 1]. The model's train/eval mode
        is restored afterwards.
    """

    # Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
    def cnt_top1_top5(output, target) -> Tuple[int, int]:
        """Count rows of ``output`` whose top-1 / top-5 predictions contain ``target``."""
        # Cap k by the number of classes so topk() also works with < 5 classes.
        maxk = min(5, output.size(1))
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # .item() yields plain Python ints instead of (possibly GPU) tensors.
        return int(correct[:1].sum().item()), int(correct[:maxk].sum().item())

    def _calc_top1_top5(model) -> Tuple[float, float]:
        use_gpu = torch.cuda.is_available()  # move inputs to GPU when available
        with timeit(logger):
            logger.info("validating")
            num_top1 = 0
            num_top5 = 0
            num_all = 0
            was_training = model.training
            model.eval()
            try:
                # The forward pass itself must run under no_grad; otherwise an
                # autograd graph is built for every validation sample.
                with torch.no_grad():
                    for _batch in _validation_data_loader:
                        for _x, _label in _batch:
                            if use_gpu:
                                _x = _x.cuda()
                                _label = _label.cuda()
                            out = model(_x)
                            top1, top5 = cnt_top1_top5(out, _label)
                            num_top1 += top1
                            num_top5 += top5
                            num_all += _label.size(0)
            finally:
                model.train(was_training)
            if num_all == 0:
                logger.warning("validation set is empty; reporting 0 accuracy")
                return 0.0, 0.0
            return num_top1 / num_all, num_top5 / num_all

    return _calc_top1_top5


def train_model(model, learning_rate, batch_size, epochs, dataset: RawDataSet,
                tag=None, validate: bool = True,
                validate_on_training: bool = False, validating_interval: int = 100):
    """
    Train ``model`` on ``dataset.train`` with Adam and cross-entropy loss.

    Samples have variable lengths, so each mini-batch is run one sample at a
    time. The per-sample losses are scaled by ``1 / len(batch)`` and their
    gradients accumulated, so each optimizer step uses the batch-mean
    gradient.

    :param model: the ``nn.Module`` to train; must be passed in on the CPU
        (it is moved to the GPU when one is available).
    :param learning_rate: Adam learning rate.
    :param batch_size: mini-batch size for training and validation.
    :param epochs: passes over the training set (``None`` = forever).
    :param dataset: ``RawDataSet`` providing the train and test splits.
    :param tag: optional name printed when training starts and ends.
    :param validate: evaluate on ``dataset.test`` after training.
    :param validate_on_training: also evaluate every ``validating_interval``
        batches during training.
    :param validating_interval: batches between periodic evaluations;
        values <= 0 disable periodic evaluation.
    :return: the trained model, moved back to the CPU.
    """

    def _calc_top1_top5(_model):
        # The test loader is a one-shot generator, so build a fresh one (and a
        # fresh evaluation function) for every evaluation.
        fn = build_test_func(
            dataset.test_data_loader(batch_size, 1)
        )
        return fn(_model)

    def _report(prefix):
        top1_acc, top5_acc = _calc_top1_top5(model)
        res = {
            "top1_accuracy": top1_acc,
            "top5_accuracy": top5_acc,
        }
        print(f"{prefix}: {res}")
        # Evaluation may switch the model to eval mode; resume in train mode.
        model.train()

    if tag is not None:
        print(f"model[{tag}] start.")
    use_gpu = torch.cuda.is_available()  # use GPU acceleration when available
    if use_gpu:
        model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    periodic = validate_on_training and validating_interval > 0
    for i, _batch in enumerate(dataset.train_data_loader(batch_size, epochs)):
        # Skip the update for an empty batch: there is nothing to learn from,
        # and stepping Adam without new gradients could still move weights.
        if _batch:
            optimizer.zero_grad()
            n_samples = len(_batch)
            for _x, _label in _batch:
                if use_gpu:
                    _x = _x.cuda()
                    _label = _label.cuda()
                out = model(_x)  # forward pass
                loss = criterion(out, _label) / n_samples
                loss.backward()  # accumulates into .grad across the batch
            optimizer.step()
        if periodic and i % validating_interval == (validating_interval - 1):
            with timeit(logger):
                _report("current")
    if validate:
        _report("last")
    if tag is not None:
        print(f"model[{tag}] done.")
    # Return a CPU model; a GPU model would fail when shipped to another
    # process/host.
    return model.cpu()


# %%
def _test_print_max_t():
    """
    Print the largest timestamp in the dataset (train and test combined).

    Each sample contributes the time of its last record, i.e. this assumes
    records within a sample are in time order. The result is never below 0.
    Reads the full dataset from the module-level ``data_dir``.
    """
    logger.info("--- 读取数据 ---")
    with timeit(logger):
        # Quick smoke run instead: read_data(data_dir, num_train=6, num_test=6)
        dataset = read_data(data_dir)
    logger.info("--- 数据读取完成 ---")

    last_times = (ele.data[-1][0] for ele in itertools.chain(dataset.train, dataset.test))
    # Seeding with 0 keeps "0 unless something larger is seen" semantics.
    print(max(itertools.chain([0], last_times)))
    # Observed on the full dataset (reading took ~1m33s): 45.1567089558


def _test_print_max_len_records():
    """
    Print the largest number of records in any single sample, and its tag.

    On ties the first sample encountered (train before test) wins.
    """
    logger.info("--- 读取数据 ---")
    with timeit(logger):
        # Reading just 2 files already showed records with no time gap:
        # read_data(data_dir, num_train=1, num_test=1)
        dataset = read_data(data_dir)
    logger.info("--- 数据读取完成 ---")
    max_len, target_tag = 0, None
    for sample in itertools.chain(dataset.train, dataset.test):
        n_records = len(sample.data)
        if n_records > max_len:
            max_len, target_tag = n_records, sample.tag
    print(max_len, target_tag)
    # Observed on the full dataset (reading took ~47s): 24206 train-43-49


def _test_print_min_time_gap(_min_gap: float = 1e-15):
    """
    Print the smallest time gap between consecutive records in the dataset.

    Three minima are printed:

    * over all records of each sample,
    * over only the records whose stream type equals 1 ("up"),
    * over all other records ("down").

    Gaps smaller than ``_min_gap`` are ignored, so the default skips the zero
    gaps between records that share a timestamp. ``inf`` is printed when no
    qualifying gap exists.
    """

    def smallest_gap(times_per_sample):
        # Minimum consecutive difference >= _min_gap over all sequences.
        best = float("inf")
        for times in times_per_sample:
            for prev, cur in zip(times, times[1:]):
                gap = cur - prev
                if _min_gap <= gap < best:
                    best = gap
        return best

    logger.info("--- 读取数据 ---")
    with timeit(logger):
        # Reading just 2 files already showed records with no time gap:
        # read_data(data_dir, num_train=1, num_test=1)
        dataset = read_data(data_dir)
    logger.info("--- 数据读取完成 ---")

    samples = list(itertools.chain(dataset.train, dataset.test))
    # Unsplit: every record of the sample.
    all_times = [[rec[0] for rec in ele.data] for ele in samples]
    # Split by stream type. NOTE(review): only the value 1 counts as "up";
    # merged records whose summed type is not 1 land in "down" — confirm.
    up_times = [[t for t, kind in ele.data if kind == 1] for ele in samples]
    down_times = [[t for t, kind in ele.data if kind != 1] for ele in samples]

    print("min_t_gap:", smallest_gap(all_times))
    print("_up_min_t_gap:", smallest_gap(up_times))
    print("_down_min_t_gap:", smallest_gap(down_times))
    # Observed on the full dataset:
    #   _min_gap = 0     -> all three minima are 0.0
    #   _min_gap = 1e-15 -> min_t_gap 5.555972570903123e-09,
    #                       _up_min_t_gap 7.97707144783999e-09,
    #                       _down_min_t_gap 5.555972570903123e-09


# %%
args = get_args()

# Hyper-parameters and paths, unpacked from the parsed arguments.
batch_size, learning_rate, num_epochs = (
    args.batchSize,     # mini-batch size
    args.learningRate,  # learning rate
    args.numEpochs,     # number of passes over the training set
)
data_dir, save_dir, n_class = args.dataDir, args.saveDir, args.nClass


# %%
# logger.info("--- 读取数据 ---")
# with timeit(logger):
#     # dataset = read_data(data_dir, num_train=6, num_test=6)
#     dataset = read_data(data_dir)
# logger.info("--- 数据读取完成 ---")
#
# max_t = 0
# for i, ele in enumerate(itertools.chain(dataset.train, dataset.test)):
#     t = ele.data[-1][0]
#     if t > max_t:
#         max_t = t
#
# print(max_t)

# %%
# _test_print_max_t()
# _test_print_min_time_gap()
# _test_print_max_len_records()


# %%
class MyModule(nn.Module):
    """
    1-D CNN classifier for variable-length record sequences.

    Two strided ``Conv1d`` + ``ReLU`` layers are followed by
    ``AdaptiveAvgPool1d(300)``. This maps an input of any (padded) length to
    a fixed 16 x 300 feature map for the final linear layer.
    """

    # Inputs whose length axis is shorter than this are right-padded with zeros.
    MIN_LEN = 300

    def __init__(self, n_class):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv1d(2, 16, 5, 3),
            nn.ReLU(),
            nn.Conv1d(16, 16, 5, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(300),
            nn.Flatten(),
            nn.Linear(16 * 300, n_class),
        )
        # NOTE: a Conv1d kernel must not be larger than its input, otherwise:
        #   RuntimeError: Calculated padded input size per channel: (9).
        #   Kernel size: (10). Kernel size can't be greater than actual input size
        # Padding short inputs to MIN_LEN in forward() prevents this.

    def forward(self, x):
        """
        Classify a batch of sequences.

        Sequences shorter than ``MIN_LEN`` records are zero-padded on the
        right; longer ones are used as-is.

        :param x: tensor of shape (batch, 2, length).
        :return: logits of shape (batch, n_class).
        """
        length = x.shape[-1]
        if length < self.MIN_LEN:
            # Pad the length (last) axis only, up to exactly MIN_LEN.
            x = F.pad(x, (0, self.MIN_LEN - length), "constant", value=0)
        return self.seq(x)


# %%
logger.info("--- 读取数据 ---")
with timeit(logger):
    # Smaller debugging run: read_data(data_dir, num_train=6, num_test=6)
    dataset = read_data(data_dir)
logger.info("--- 数据读取完成 ---")

# %%

model = MyModule(n_class)
# Train, validating every 100 batches and once more at the end.
train_options = dict(
    learning_rate=learning_rate,
    batch_size=batch_size,
    epochs=num_epochs,
    dataset=dataset,
    tag=None,
    validate=True,
    validate_on_training=True,
    validating_interval=100,
)
train_model(model=model, **train_options)

print("over")
